{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from numpy import *\n",
    "from kNN import classify0\n",
    "\n",
    "def img2vector(filename):\n",
    "    \"\"\"\n",
    "    Flatten a 32x32 text-encoded digit image into a single row vector.\n",
    "\n",
    "    The file must hold at least 32 lines, each starting with at least 32\n",
    "    single-digit characters; the character at row i, column j is stored at\n",
    "    index 32 * i + j of the result.\n",
    "\n",
    "    :param filename: path of the text file to read\n",
    "    :return: array of shape (1, 1024) holding the pixel values\n",
    "    \"\"\"\n",
    "    res = zeros((1, 1024))\n",
    "    # The context manager closes the handle even if a malformed line raises.\n",
    "    with open(filename) as fr:\n",
    "        for i in range(32):\n",
    "            line = fr.readline()\n",
    "            for j in range(32):\n",
    "                res[0, 32 * i + j] = int(line[j])\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  1.,\n",
       "        1.,  1.,  1.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,\n",
       "        0.,  0.,  0.,  0.,  0.])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Smoke test: vectorise one training sample and show the first 31 values of its row\n",
    "test_vt = img2vector('./dataset/trainingDigits/0_5.txt')\n",
    "test_vt[0, :31]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from os import listdir\n",
    "\n",
    "def _labelFromFilename(filename):\n",
    "    \"\"\"Return the integer that precedes the first underscore in a name such as '0_5.txt'.\"\"\"\n",
    "    return int(filename.split('.')[0].split('_')[0])\n",
    "\n",
    "\n",
    "def handwritingClassTest(trainingDir='./dataset/trainingDigits',\n",
    "                         testDir='./dataset/testDigits', k=3):\n",
    "    \"\"\"\n",
    "    Evaluate the kNN digit classifier and print the error count and error rate.\n",
    "\n",
    "    Every file in trainingDir becomes one row of the training matrix. Every\n",
    "    file in testDir is classified with classify0 and the result is compared\n",
    "    with the label taken from its file name.\n",
    "\n",
    "    :param trainingDir: directory holding the training digit files\n",
    "    :param testDir: directory holding the test digit files\n",
    "    :param k: passed to classify0 as its fourth argument\n",
    "              (presumably the number of neighbours - confirm against kNN.classify0)\n",
    "    :return: None; the results are printed\n",
    "    \"\"\"\n",
    "    print(\"come backing...\")\n",
    "\n",
    "    # Build the training matrix and its labels, one row per file.\n",
    "    trainingFileList = listdir(trainingDir)\n",
    "    m = len(trainingFileList)\n",
    "    hwLabels = [_labelFromFilename(name) for name in trainingFileList]\n",
    "    trainingMat = zeros((m, 1024))\n",
    "    for i, filename in enumerate(trainingFileList):\n",
    "        trainingMat[i, :] = img2vector('%s/%s' % (trainingDir, filename))\n",
    "\n",
    "    # Classify every test file and count the mismatches.\n",
    "    testFileList = listdir(testDir)\n",
    "    mTest = len(testFileList)\n",
    "    errorCount = 0\n",
    "    for filename in testFileList:\n",
    "        expectedLabel = _labelFromFilename(filename)\n",
    "        vt_underTest = img2vector('%s/%s' % (testDir, filename))\n",
    "        classificationResult = classify0(vt_underTest, trainingMat, hwLabels, k)\n",
    "        if classificationResult != expectedLabel:\n",
    "            errorCount += 1\n",
    "\n",
    "    print(\"the number of errors is: %d.\" % errorCount)\n",
    "    # An empty test directory has no defined error rate; report nan instead of\n",
    "    # failing with ZeroDivisionError.\n",
    "    errorRate = errorCount / float(mTest) if mTest else float('nan')\n",
    "    print(\"the total error rate is: %f\" % errorRate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "come backing...\n",
      "the number of errors is: 10.\n",
      "the total error rate is: 0.010571\n"
     ]
    }
   ],
   "source": [
    "# Run the evaluation over the digit dataset; prints the error count and error rate.\n",
    "handwritingClassTest()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}